---
title: "Revise LR model"
author: "Rachel Smith"
date: "`r Sys.Date()`"
output: pdf_document
---

Code to identify the best-fitting logistic regression model for each of cerebral Hb_tot and cerebral Hb_tot alpha.

```{r setup, include = FALSE}

# chunk defaults for the whole document: hide code, messages and warnings,
# and use one shared figure size
knitr::opts_chunk$set(
  echo = FALSE,
  message = FALSE,
  warning = FALSE,
  fig.height = 5.5,
  fig.width = 8
)

# libraries
library(tidyverse)
library(RColorBrewer)
library(gridExtra)
library(patchwork)
library(Hmisc)
library(corrr)
library(tidymodels)
library(themis)
library(vip)
library(janitor)
library(kableExtra)

# default ggplot theme: black-and-white base with enlarged text elements
theme_set(
  theme_bw() +
    theme(
      plot.title = element_text(size = 18),
      axis.title = element_text(size = 15),
      legend.title = element_text(size = 15),
      axis.text = element_text(size = 12),
      strip.text = element_text(size = 12),
      legend.text = element_text(size = 12)
    )
)

# named colour vector, one colour per patient-group code
my_col <- c(CM = "#00BFC4", HC = "#7CAE00", UM = "#F8766D")

```

```{r data, echo = FALSE}

# load
# Read the exported data table, standardise column names with clean_names(),
# relabel status "HV" as "HC", and keep the rows used in this analysis.
df_hbox <- read_csv("Data/Hans_datatable_exports/malawi key data v04Jan2022.csv") %>% 
  as_tibble() %>% 
  janitor::clean_names() %>% 
  mutate(status = replace(status, status == "HV", "HC")) %>% 
  dplyr::rename("subject_id" = "subject_id_session") %>% 
  
  # keep session-1 rows of the three groups, plus one explicitly listed subject ID;
  # the parentheses spell out the precedence the expression already had
  # (`&` binds tighter than `|`)
  dplyr::filter(
    (status %in% c("CM", "HC", "UM") & session_no == 1) | subject_id == "TM0003CM01"
  ) %>% 
  dplyr::filter(!grepl("blood", subject_id)) %>% 
  mutate(status = factor(status, levels = c("HC", "UM", "CM"))) %>% 
  # convert the remaining character variables to numerics
  mutate(across(all_of(c("bp_dias", "bp_sys", "glucose", "lactate")), as.numeric)) %>% 
  dplyr::select(subject_id, status, everything()) %>% 
  # NOTE: df_hbox stays grouped by status; every later chunk that starts from
  # df_hbox inherits this grouping
  group_by(status)

# format
  # define clinical variables of interest
  clinical_voi <- c("age_calc", "temperature", "hr", "bp_sys", "bp_dias", "resp_rate", "pulse_ox", "hct", "lactate", "glucose", "aa_arg")
  hbox_voi <- c("cerebral_hb_tot", "cerebral_hb_tot_alpha2")
  
  # select clinical VOI and code sex as 1 = "Female", 0 = otherwise
  df_model <- df_hbox %>% 
    dplyr::select(status, sex, all_of(hbox_voi), all_of(clinical_voi)) %>% 
    mutate(sex_coded = ifelse(sex == "Female", 1, 0)) %>% 
    dplyr::select(-sex)
  
  # impute missing data with median by group
  # f_impute: replace the NA entries of x with the median of x
  # (the median ignores NAs unless na.rm = FALSE)
  f_impute <- function(x, na.rm = TRUE) {
    replace(x, is.na(x), median(x, na.rm = na.rm))
  }
  
  # across(everything()) skips the grouping column (status) inside a grouped
  # mutate(), so every other column is imputed within its own patient group.
  # This replaces a position-based selection (2:ncol(.)), which is fragile on
  # grouped data because the grouping column may not be counted among the positions.
  df_model_impute <- df_model %>% 
    group_by(status) %>% 
    mutate(across(everything(), f_impute)) %>% 
    ungroup() 

```

## Chi-squared test for differences in proportion of sex by patient group

HC, UM, CM
```{r sex_chisq_all}

# HC, UM, CM
# contingency table: one row per patient group, one column per sex
df_sex_chisq_all <- df_hbox %>% 
  dplyr::select(status, sex) %>% 
  count(status, sex) %>% 
  # values_fill = 0L: a group/sex combination with no subjects becomes a zero
  # count instead of NA, which chisq.test() would reject
  pivot_wider(id_cols = status, names_from = sex, values_from = n, values_fill = 0L) %>% 
  as.data.frame() %>% 
  column_to_rownames("status")
  
knitr::kable(df_sex_chisq_all)

# chi-squared test of independence between patient group and sex
chisq.test(df_sex_chisq_all)

```

UM vs CM
```{r sex_chisq_um_cm}

# UM vs CM
# contingency table restricted to the two malaria groups (HC rows removed)
df_sex_chisq_um_cm <- df_hbox %>% 
  dplyr::select(status, sex) %>% 
  filter(status != "HC") %>% 
  count(status, sex) %>% 
  # values_fill = 0L: a group/sex combination with no subjects becomes a zero
  # count instead of NA, which chisq.test() would reject
  pivot_wider(id_cols = status, names_from = sex, values_from = n, values_fill = 0L) %>% 
  as.data.frame() %>% 
  column_to_rownames("status")
  
knitr::kable(df_sex_chisq_um_cm)

# chi-squared test of independence between patient group and sex
chisq.test(df_sex_chisq_um_cm)

```

```{r lr_model}

# write function to run LR model
#
# f_run_lr: evaluate a glm logistic regression of `status` on all other columns
# of `df_model` over stratified bootstrap resamples, then fit it to all of
# `df_model`.
#
# Arguments:
#   df_model  - data frame with the outcome `status`, the identifier column
#               `subject_id`, and the predictor columns
#   resamples - number of bootstrap resamples to draw
#
# Returns: the parsnip model fit extracted from the workflow fitted to
#          `df_model`. The resampling metrics are computed but not returned.
f_run_lr <- function(df_model, resamples) {
  
  # bootstrap resamples, stratified by status
  hbox_boot <- bootstraps(df_model, strata = status, times = resamples)
  
  # recipe: subject_id is given the "id" role first so it is clearly not a
  # predictor; all_predictors() is resolved when the recipe is prepped
  hbox_rec <- recipe(status ~ ., data = df_model) %>% 
    update_role(subject_id, new_role = "id") %>% 
    step_impute_knn(all_predictors()) %>% 
    step_normalize(all_predictors()) %>%
    step_zv(all_predictors()) %>%
    step_smote(status)
  
  # logistic regression model specification and workflow (recipe + model)
  glm_spec <- logistic_reg() %>% set_engine("glm")
  hbox_wf <- workflow() %>% 
    add_recipe(hbox_rec) %>% 
    add_model(glm_spec)
  
  # fit and evaluate the model on each of the `resamples` bootstrap resamples
  glm_res <- hbox_wf %>% 
    fit_resamples(
      resamples = hbox_boot,
      metrics = metric_set(roc_auc, accuracy, sensitivity, specificity),
      control = tune::control_resamples(save_pred = TRUE, verbose = TRUE)
    )
  
  # the metric must be passed by name: newer tune versions reject a positional
  # metric argument. The specification has no tuning parameters, so this
  # selects the single configuration that was resampled.
  glm_best <- glm_res %>% select_best(metric = "roc_auc")
  
  # finalize the workflow, fit it to the full data, and return the parsnip fit
  hbox_wf %>% 
    finalize_workflow(glm_best) %>%
    fit(df_model) %>%
    extract_fit_parsnip()
  
}

```

## Shapiro-Wilk test for normality
```{r shapiro_wilk}

voi <- c("cerebral_hb_tot", "cerebral_hb_tot_alpha2", "sex", "resp_rate", "hct", "lactate")

# long format: one row per observation and variable; sex is dropped because
# only the remaining variables are plotted and tested
df_hbox_voi <- df_hbox %>% 
  dplyr::select(status, all_of(voi)) %>% 
  dplyr::select(-sex) %>% 
  pivot_longer(-status, names_to = "variable", values_to = "value")

# plot distributions
df_hbox_voi %>%  
  ggplot(aes(x = status, y = value)) +
  geom_violin(aes(color = status)) +
  geom_jitter(position = position_jitter(0.15), shape = 1, size = 1.5) +
  stat_summary(fun = "median", geom = "crossbar", aes(color = status), size = 0.2, width = 0.5) +
  facet_wrap(~ variable, scales = "free") +
  labs(x = "", y = "") +
  ggtitle("Clinical variable distributions by patient group") +
  theme(legend.position = "none")

# Shapiro-Wilk test, one test per patient group and variable
df_shapiro <- df_hbox_voi %>% 
  group_by(status, variable) %>% 
  nest() %>% 
  # nest() keeps the (status, variable) grouping, which leaves one row per
  # group; ungroup so the BH correction below sees all p-values together
  # instead of adjusting each p-value on its own (which changes nothing)
  ungroup() %>% 
  
  mutate(shapiro_test = map(.x = data,
                            .f = ~ shapiro.test(.x$value))) %>% 
  mutate(p_val = map_dbl(.x = shapiro_test,
                         .f = ~ .x$p.value)) %>% 
  mutate(p_adj = p.adjust(p_val, method = "BH")) %>% 
  dplyr::select(status, variable, p_adj) %>% 
  pivot_wider(id_cols = status, names_from = variable, values_from = p_adj) %>% 
  mutate(across(where(is.numeric), ~ round(.x, digits = 3)))
  
df_shapiro %>% 
  knitr::kable()

```

## Cerebral Hb_tot model fitting

status ~ log10(cerebral_hb_tot) + sex + log10(resp_rate) + hct + log10(lactate)
```{r hb_tot_model1}

# modelling data: UM and CM subjects only, status releveled to c("UM", "CM"),
# sex coded as 1 = "Female" / 0 = otherwise
df_model <- df_hbox %>% 
  filter(status != "HC") %>% 
  mutate(status = factor(status, levels = c("UM", "CM"))) %>% 
  mutate(
    sex = ifelse(sex == "Female", 1, 0),
    # transform non-normal variables (by Shapiro-Wilk) using log-transform
    log10_cerebral_hb_tot = log10(cerebral_hb_tot),
    log10_resp_rate = log10(resp_rate),
    log10_lactate = log10(lactate)
  )

set.seed(20221125)

# 1. status ~ log10_cerebral_hb_tot + sex + log10_resp_rate + hct + log10_lactate
hbtot_final1 <- f_run_lr(
  df_model = dplyr::select(
    df_model,
    subject_id, status, log10_cerebral_hb_tot,
    sex, log10_resp_rate, hct, log10_lactate
  ),
  resamples = 10
)

# coefficient table without the intercept, rounded to 3 decimals
hbtot_final1 %>% 
  tidy() %>% 
  filter(term != "(Intercept)") %>% 
  mutate(across(where(is.numeric), ~ round(.x, 3))) %>% 
  knitr::kable() # AIC 82.61

```

status ~ log10(cerebral_hb_tot) + sex + hct + log10(lactate)
```{r hb_tot_model2}

# 2. status ~ log10_cerebral_hb_tot + sex + hct + log10_lactate

set.seed(20221125)

# resample: 5000 bootstrap resamples, stratified by status
hbtot_boot <- bootstraps(df_model, strata = status, times = 5000)

# recipe: subject_id enters the formula only so it can be carried along with
# the "id" role; it is not used as a predictor
hbtot_rec <- recipe(status ~ log10_cerebral_hb_tot + hct + log10_lactate + sex + subject_id, data = df_model) %>% 
  update_role(subject_id, new_role = "id") %>% 
  step_impute_knn(all_predictors()) %>% 
  step_normalize(all_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_smote(status)

# create workflow (recipe only; the model is added where the workflow is used,
# because later chunks call add_model() on hbtot_wf themselves)
hbtot_wf <- workflow() %>% add_recipe(hbtot_rec)

# logistic regression model specifications
hbtot_glm_spec <- logistic_reg() %>% set_engine("glm")

# fit and evaluate the model on the 5000 bootstrap resamples
# NOTE(review): yardstick treats the first factor level as the event by
# default; status is releveled to c("UM", "CM") when df_model is built, so
# sensitivity/specificity presumably refer to UM as the event - confirm that
# this is the intended positive class
hbtot_glm_res <- hbtot_wf %>% 
  add_model(hbtot_glm_spec) %>%
  fit_resamples(
    resamples = hbtot_boot,
    metrics = metric_set(roc_auc, accuracy, sensitivity, specificity),
    control = tune::control_resamples(save_pred = TRUE, verbose = TRUE)
  )

# the metric must be passed by name: newer tune versions reject a positional
# metric argument. The specification has no tuning parameters, so this selects
# the single configuration that was resampled.
hbtot_glm_best <- hbtot_glm_res %>% select_best(metric = "roc_auc")

# finalize model: fit on the full data and keep the parsnip fit
hbtot_final2 <- hbtot_wf %>% 
  add_model(hbtot_glm_spec) %>% 
  finalize_workflow(hbtot_glm_best) %>%
  fit(df_model) %>%
  extract_fit_parsnip()

hbtot_final2 %>% 
  tidy() %>% 
  filter(term != "(Intercept)") %>% 
  mutate(across(where(is.numeric), ~ round(.x, 3))) %>% 
  knitr::kable() # AIC 81.92, every term is significant

```

```{r assess_hb_tot_model}

hbtot_final2$fit
# AIC 76.25

# confidence intervals from confint(), exponentiated to the odds-ratio scale
df_CI <- confint(hbtot_final2$fit) %>% exp() %>% as_tibble(rownames = "term")

# odds-ratio table for the significant terms (p < 0.05), intercept excluded;
# values below 0.01 are shown in scientific notation, others rounded to 2 decimals
hbtot_final2 %>%
  tidy(exponentiate = TRUE) %>%
  filter(p.value < 0.05) %>% 
  left_join(df_CI, by = "term") %>%
  mutate(estimate = ifelse(estimate < 10^-2, scientific(estimate, digits = 3), round(estimate, 2)),
         conf.low = ifelse(`2.5 %` < 10^-2, scientific(`2.5 %`, digits = 3), round(`2.5 %`, 2)),
         conf.high = ifelse(`97.5 %` < 10^-2, scientific(`97.5 %`, digits = 3), round(`97.5 %`, 2))) %>%
  mutate(p.value = scientific(p.value, digits = 3),
         statistic = round(statistic, 2)) %>%
  unite("95% CI", c(conf.low, conf.high), sep = ", ") %>%
  dplyr::select(c(term, estimate, `95% CI`, statistic, p.value)) %>% 
  # the test statistic of a binomial glm is a Wald z value, not a t value
  dplyr::rename("Term" = "term", "Odds ratio" = "estimate",
                "z-statistic" = "statistic", "p-value" = "p.value") %>%
  filter(Term != "(Intercept)") %>% 
  as.data.frame() %>% 
  
  knitr::kable(align = rep('c', 5))

# resampling performance as percentages; display labels are matched on the
# metric name so they cannot drift if the row order of collect_metrics() changes
hbtot_glm_res %>% 
  collect_metrics() %>% 
  dplyr::select(.metric, mean, std_err) %>% 
  mutate(mean = mean*100, std_err = std_err*100) %>% 
  mutate(Metric = dplyr::recode(.metric,
                                accuracy = "Accuracy",
                                roc_auc = "ROC AUC",
                                sensitivity = "Sensitivity",
                                specificity = "Specificity")) %>% 
  dplyr::select(c("Metric", "mean", "std_err")) %>% 
  
  dplyr::rename("Mean" = "mean", "Standard error" = "std_err") %>% 
  mutate(across(where(is.numeric), ~ round(.x, 2))) %>% 
  as.data.frame() %>% 
  
  knitr::kable()

```

```{r hb_tot_imp, fig.height = 4}

# permutation-based variable importance (10 permutations per variable),
# computed on the prepped training data from the recipe
set.seed(123)
hbtot_imp <- hbtot_wf %>% 
  add_model(hbtot_glm_spec) %>% 
  fit(df_model) %>% 
  # extract_fit_parsnip() replaces the deprecated pull_workflow_fit()
  extract_fit_parsnip() %>% 
  
  # NOTE(review): kernlab::predict is used as the prediction wrapper for a glm
  # fit - confirm it returns the scores the "auc" metric expects
  vi(
    method = "permute", nsim = 10,
    train = juice(hbtot_rec %>% prep()),
    target = "status", metric = "auc", reference_class = "CM",
    pred_wrapper = kernlab::predict
    
  )

# plot importance +/- its standard deviation; the identifier column is dropped
# and variables are ordered by importance
hbtot_imp %>% 
  clean_names() %>% 
  filter(variable != "subject_id") %>% 
  mutate(variable = fct_reorder(variable, importance)) %>% 
  
  ggplot(aes(x = importance, y = variable, color = variable)) +
  geom_errorbar(aes(xmin = importance - st_dev, xmax = importance + st_dev),
                alpha = 0.5, size = 1.3) +
  geom_point(size = 3) +
  theme(legend.position = "none",
        plot.title = element_text(size = 18),
        axis.title = element_text(size = 15),
        axis.text = element_text(size = 12),
        strip.text = element_text(size = 12),
        legend.title = element_text(size = 15),
        legend.text = element_text(size = 12)) +
  labs(y = NULL, x = "relative importance") +
  ggtitle("Variable importance in distinguishing CM vs UM")

```

## Cerebral Hb_tot_alpha model fitting

status ~ cerebral_hb_tot_alpha2 + sex + log10(resp_rate) + hct + log10(lactate)
```{r hb_tot_alpha_model1}

# 1. status ~ cerebral_hb_tot_alpha2 + sex + log10_resp_rate + hct + log10_lactate
# NOTE(review): no set.seed() call in this chunk, so the resampling depends on
# the RNG state left by earlier chunks - confirm this is intended
hbtot_alpha_final1 <- f_run_lr(
  df_model = dplyr::select(
    df_model,
    subject_id, status, cerebral_hb_tot_alpha2,
    sex, log10_resp_rate, hct, log10_lactate
  ),
  resamples = 10
)

# coefficient table without the intercept, rounded to 3 decimals
hbtot_alpha_final1 %>% 
  tidy() %>% 
  filter(term != "(Intercept)") %>% 
  mutate(across(where(is.numeric), ~ round(.x, 3))) %>% 
  knitr::kable() # AIC 62.41

```

status ~ cerebral_hb_tot_alpha2 + sex + hct + log10(lactate)
```{r hb_tot_alpha_model2}

# 2. status ~ cerebral_hb_tot_alpha2 + sex + hct + log10_lactate

set.seed(20221125)

# resample: 5000 bootstrap resamples, stratified by status
hbtot_alpha_boot <- bootstraps(df_model, strata = status, times = 5000)

# recipe: subject_id enters the formula only so it can be carried along with
# the "id" role; it is not used as a predictor
hbtot_alpha_rec <- recipe(status ~ cerebral_hb_tot_alpha2 + hct + sex + log10_lactate + subject_id, data = df_model) %>% 
  update_role(subject_id, new_role = "id") %>% 
  step_impute_knn(all_predictors()) %>% 
  step_normalize(all_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_smote(status)

# create workflow (recipe only; the model is added where the workflow is used,
# because later chunks call add_model() on hbtot_alpha_wf themselves)
hbtot_alpha_wf <- workflow() %>% add_recipe(hbtot_alpha_rec)

# logistic regression model specifications
hbtot_alpha_glm_spec <- logistic_reg() %>% set_engine("glm")

# fit and evaluate the model on the 5000 bootstrap resamples
hbtot_alpha_glm_res <- hbtot_alpha_wf %>% 
  add_model(hbtot_alpha_glm_spec) %>%
  fit_resamples(
    resamples = hbtot_alpha_boot,
    metrics = metric_set(roc_auc, accuracy, sensitivity, specificity),
    control = tune::control_resamples(save_pred = TRUE, verbose = TRUE)
  )

# the metric must be passed by name: newer tune versions reject a positional
# metric argument. The specification has no tuning parameters, so this selects
# the single configuration that was resampled.
hbtot_alpha_glm_best <- hbtot_alpha_glm_res %>% select_best(metric = "roc_auc")

# finalize model: fit on the full data and keep the parsnip fit
hbtot_alpha_final2 <- hbtot_alpha_wf %>% 
  add_model(hbtot_alpha_glm_spec) %>% 
  finalize_workflow(hbtot_alpha_glm_best) %>%
  fit(df_model) %>%
  extract_fit_parsnip()

hbtot_alpha_final2 %>% 
  tidy() %>% 
  filter(term != "(Intercept)") %>% 
  mutate(across(where(is.numeric), ~ round(.x, 3))) %>% 
  knitr::kable() # AIC 65.29

## save final models!
# make sure the output directory exists so the save cannot fail after the long
# resampling run
dir.create("objects", showWarnings = FALSE, recursive = TRUE)
save(hbtot_final2, hbtot_alpha_final2,
     file = "objects/20221125_final_glm_models.Rdata")

```

```{r assess_hb_tot_alpha_model}

hbtot_alpha_final2$fit
# AIC 61.05

# confidence intervals from confint(), exponentiated to the odds-ratio scale
df_CI <- confint(hbtot_alpha_final2$fit) %>% exp() %>% as_tibble(rownames = "term")

# odds-ratio table for the significant terms (p < 0.05), intercept excluded;
# values below 0.01 are shown in scientific notation, others rounded to 2 decimals
hbtot_alpha_final2 %>%
  tidy(exponentiate = TRUE) %>%
  filter(p.value < 0.05) %>% 
  left_join(df_CI, by = "term") %>%
  mutate(estimate = ifelse(estimate < 10^-2, scientific(estimate, digits = 3), round(estimate, 2)),
         conf.low = ifelse(`2.5 %` < 10^-2, scientific(`2.5 %`, digits = 3), round(`2.5 %`, 2)),
         conf.high = ifelse(`97.5 %` < 10^-2, scientific(`97.5 %`, digits = 3), round(`97.5 %`, 2))) %>%
  mutate(p.value = scientific(p.value, digits = 3),
         statistic = round(statistic, 2)) %>%
  unite("95% CI", c(conf.low, conf.high), sep = ", ") %>%
  dplyr::select(c(term, estimate, `95% CI`, statistic, p.value)) %>% 
  # the test statistic of a binomial glm is a Wald z value, not a t value
  dplyr::rename("Term" = "term", "Odds ratio" = "estimate",
                "z-statistic" = "statistic", "p-value" = "p.value") %>%
  filter(Term != "(Intercept)") %>% 
  as.data.frame() %>% 
  
  knitr::kable(align = rep('c', 5))

# resampling performance as percentages; display labels are matched on the
# metric name so they cannot drift if the row order of collect_metrics() changes
hbtot_alpha_glm_res %>% 
  collect_metrics() %>% 
  dplyr::select(.metric, mean, std_err) %>% 
  mutate(mean = mean*100, std_err = std_err*100) %>% 
  mutate(Metric = dplyr::recode(.metric,
                                accuracy = "Accuracy",
                                roc_auc = "ROC AUC",
                                sensitivity = "Sensitivity",
                                specificity = "Specificity")) %>% 
  dplyr::select(c("Metric", "mean", "std_err")) %>% 
  
  dplyr::rename("Mean" = "mean", "Standard error" = "std_err") %>% 
  mutate(across(where(is.numeric), ~ round(.x, 2))) %>% 
  as.data.frame() %>% 
  
  knitr::kable()

```

```{r hb_tot_alpha_imp, fig.height = 4}

# permutation-based variable importance (10 permutations per variable),
# computed on the prepped training data from the recipe
set.seed(123)
hbtot_alpha_imp <- hbtot_alpha_wf %>% 
  add_model(hbtot_alpha_glm_spec) %>% 
  fit(df_model) %>% 
  # extract_fit_parsnip() replaces the deprecated pull_workflow_fit()
  extract_fit_parsnip() %>% 
  
  # NOTE(review): kernlab::predict is used as the prediction wrapper for a glm
  # fit - confirm it returns the scores the "auc" metric expects
  vi(
    method = "permute", nsim = 10,
    train = juice(hbtot_alpha_rec %>% prep()),
    target = "status", metric = "auc", reference_class = "CM",
    pred_wrapper = kernlab::predict
    
  )

# plot importance +/- its standard deviation; the identifier column is dropped
# and variables are ordered by importance
hbtot_alpha_imp %>% 
  clean_names() %>% 
  filter(variable != "subject_id") %>% 
  mutate(variable = fct_reorder(variable, importance)) %>% 
  
  ggplot(aes(x = importance, y = variable, color = variable)) +
  geom_errorbar(aes(xmin = importance - st_dev, xmax = importance + st_dev),
                alpha = 0.5, size = 1.3) +
  geom_point(size = 3) +
  theme(legend.position = "none",
        plot.title = element_text(size = 18),
        axis.title = element_text(size = 15),
        axis.text = element_text(size = 12),
        strip.text = element_text(size = 12),
        legend.title = element_text(size = 15),
        legend.text = element_text(size = 12)) +
  labs(y = NULL, x = "relative importance") +
  ggtitle("Variable importance in distinguishing CM vs UM")

```



